import numpy as np
import torch


def cal_selection(not_avail_idx: list, c_loss: torch.Tensor, n_p: int,
                  arg_max: torch.Tensor = None, diversity=False, max_p=5) -> list:
    """
    Constrained Active learning strategy.
    We take the n_p elements which most violate the constraints and are
    among the available idx.

    :param not_avail_idx: unavailable data (already selected); may be empty
    :param c_loss: constraint violation calculated for each point
    :param n_p: number of points to select
    :param arg_max: the number of the most violated rule for each point
        (required when ``diversity`` is True)
    :param diversity: whether to select points based also on their diversity,
        i.e. taking at most ``max_p`` points for each most-violated rule
    :param max_p: maximum number of points selected for the same rule when
        ``diversity`` is True
    :return list of the selected idx (python ints), ordered by decreasing
        constraint violation
    """
    assert not diversity or arg_max is not None, "In case the diversity strategy " \
                                                 "is applied, arg max need to be passed"

    c_loss = c_loss.clone().detach()
    not_avail = torch.as_tensor(not_avail_idx)
    # An empty list becomes a float tensor, which torch refuses as an index.
    if not_avail.numel() > 0:
        # Push already-selected points to the bottom of the ranking.
        c_loss[not_avail] = -1
    cal_idx = torch.argsort(c_loss, descending=True).tolist()
    if not diversity:
        return cal_idx[:n_p]

    selected_idx = []
    if n_p <= 0:
        return selected_idx
    per_rule_count = {}
    for index in cal_idx:
        rule = arg_max[index].item()
        count = per_rule_count.get(rule, 0)
        if count >= max_p:
            # This rule already has its quota of points: favour diversity.
            continue
        per_rule_count[rule] = count + 1
        selected_idx.append(index)
        if len(selected_idx) == n_p:
            break
    assert len(selected_idx) == n_p, "Error in the diversity " \
                                     "selection operation"
    return selected_idx


def random_selection(avail_idx: list, n_p: int) -> list:
    """
    Random Active learning strategy
    Theoretically the worst possible strategy. At each iteration
    we just take n elements randomly, without repetition.

    :param avail_idx: available data (not already selected)
    :param n_p: number of points to select; capped at ``len(avail_idx)``
    :return selected idx (distinct elements of ``avail_idx``)
    """
    n_p = min(n_p, len(avail_idx))
    if n_p <= 0:
        return []
    # replace=False: the numpy default (True) could pick the same point twice.
    random_idx = np.random.choice(avail_idx, n_p, replace=False).tolist()
    return random_idx


def supervised_selection(not_avail_idx: list, s_loss: torch.Tensor, n_p: int) -> list:
    """
    Supervised Active learning strategy
    Possibly an upper bound to a learning strategy efficacy (fake, obviously).
    We directly select the point which mostly violates the supervision loss.

    :param not_avail_idx: unavailable data (already selected); may be empty
    :param s_loss: supervision violation calculated for each point
    :param n_p: number of points to select
    :return: selected idx, ordered by decreasing supervision loss
    """
    s_loss = s_loss.clone().detach()
    not_avail = torch.as_tensor(not_avail_idx)
    # An empty list becomes a float tensor, which torch refuses as an index.
    if not_avail.numel() > 0:
        # Push already-selected points to the bottom of the ranking.
        s_loss[not_avail] = -1
    sup_idx = torch.argsort(s_loss, descending=True).tolist()[:n_p]
    return sup_idx


def uncertainty_loss(p: torch.Tensor):
    """
    Measure how close each prediction lies to the decision boundary (0.5).

    The distance from the boundary, |p - 0.5|, is rescaled and flipped so
    that a prediction exactly on the boundary scores 1 and a prediction
    of 0 or 1 scores 0:
    uncertainty = 1 - 2 * ||preds - 0.5||
    For inputs with more than one dimension the distance is averaged over
    dim 1 before rescaling.

    :param p: predictions of the network
    :return: uncertainty measure
    """
    gap_from_boundary = (p - 0.5).abs()
    if p.dim() > 1:
        # Average the per-output distances so each point gets a single score.
        gap_from_boundary = gap_from_boundary.mean(dim=1)
    return 1 - 2 * gap_from_boundary


def uncertainty_selection(not_avail_idx: list, u_loss: torch.Tensor, n_p: int) -> list:
    """
    Uncertainty Active learning strategy
    We take n elements which are the ones on which the networks is
    mostly uncertain (i.e. the points lying closer to the decision boundaries).

    :param not_avail_idx: unavailable data (already selected); may be empty
    :param u_loss: uncertainty calculated for each point (left unmodified)
    :param n_p: number of points to select
    :return selected idx, ordered by decreasing uncertainty
    """
    # Work on a copy so the caller's uncertainty tensor is not overwritten.
    u_loss = u_loss.clone().detach()
    not_avail = torch.as_tensor(not_avail_idx)
    # An empty list becomes a float tensor, which torch refuses as an index.
    if not_avail.numel() > 0:
        # Push already-selected points to the bottom of the ranking.
        u_loss[not_avail] = -1
    uncertain_idx = torch.argsort(u_loss, descending=True).tolist()[:n_p]
    return uncertain_idx


# Strategy identifiers; presumably used by callers to pick one of the
# selection functions above — confirm against the call sites.
SUPERVISED: str = "supervised"  # supervision-loss based (oracle-like) selection
RANDOM: str = "random"  # uniform random selection
CAL: str = "constrained"  # constraint-violation based selection
UNCERTAIN: str = "uncertainty"  # closest-to-boundary selection

